
Add dataset type olmo_grain for AI2 OLMo numpy pretrain mixes #3749

Open
gagika wants to merge 1 commit into main from gagik-olmo-data

Conversation


@gagika (Collaborator) commented Apr 26, 2026

Description

  • Adds dataset_type=olmo_grain, a Grain-based input pipeline for AI2's
    pre-tokenized OLMo numpy mixes (e.g. OLMo-mix-0925-official.txt).
    Reads headerless .npy token streams from a gcsfuse mount, applies
    OLMo-core's repeated-n-gram filter, and yields the shapes the MaxText
    pretrain trainer expects.
  • Stateless sampler: record at step k is a pure function of
    (seed, shard, k). Resume reads the latest step from
    config.checkpoint_dir and shifts the sampler — no Grain iterator
    state in the checkpoint.
  • Ships two data tools (download_olmo_data_to_gcs.py with HTTP-Range
    resume; build_olmo_npy_index.py for header-scan indexing) and two
    launchers (run_olmo3_7b_grain_smoke.sh,
    run_olmo3_7b_grain_resume_test.sh).
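The stateless-sampler idea described above can be sketched as follows; the names and the per-epoch permutation scheme here are illustrative, not the PR's exact implementation:

```python
import numpy as np

def record_index(seed, shard, step, num_shards, total_records):
    """Sketch of a stateless sampler: the record consumed by `shard` at
    `step` is a pure function of (seed, shard, step). Resuming at step k
    just means calling this with k -- no iterator state is checkpointed."""
    epoch, offset = divmod(step * num_shards + shard, total_records)
    rng = np.random.default_rng([seed, epoch])  # fresh permutation per epoch
    perm = rng.permutation(total_records)       # in-memory; see review notes
    return int(perm[offset])

# Deterministic: the same (seed, shard, step) always maps to the same record.
i1 = record_index(seed=0, shard=0, step=5, num_shards=4, total_records=100)
i2 = record_index(seed=0, shard=0, step=5, num_shards=4, total_records=100)
assert i1 == i2
```

Because the mapping is a bijection per epoch, every record is visited exactly once before the permutation is re-drawn.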

Tests

  • Unit tests pass (tests/unit/input_pipeline/olmo_*)
  • Smoke train: 50 steps, loss 11.99 → 8.93 on v4-8 (4-layer bf16)
  • Resume test: Run B picks up at step 50 with loss continuity 8.931 → 8.930

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

Co-authored-by: @aireenmei

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

@codecov

codecov Bot commented Apr 26, 2026

@gagika force-pushed the gagik-olmo-data branch from 2c8507c to 9de5321 on May 2, 2026 18:47
@gagika force-pushed the gagik-olmo-data branch from 9de5321 to 9e3ff8f on May 4, 2026 14:22
@github-actions

github-actions Bot commented May 5, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions Bot left a comment

📋 Review Summary

This Pull Request introduces a high-quality, Grain-based input pipeline for AI2's OLMo numpy datasets. The implementation is robust, well-documented, and includes a particularly clean approach to stateless resumption by deriving the data offset from the model checkpoint step.

🔍 General Feedback

  • Stateless Resume: The initial_step logic in the sampler is an excellent design choice that avoids the complexities of Grain iterator-state serialization.
  • N-gram Filtering: The integration of OLMo-core's repetition filter via a custom transform that masks instances in the loss is both efficient and sharding-friendly.
  • Testing and Validation: The inclusion of unit tests, smoke scripts, and end-to-end resume tests provides great confidence in the stability of the new pipeline.
  • Performance: While the in-memory permutation for shuffling is currently manageable, it's worth monitoring as dataset sizes scale further.
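The stateless resume praised above hinges on discovering the latest completed step from the checkpoint directory. A minimal sketch, assuming Orbax-style numeric step subdirectories (the PR's actual layout may differ):

```python
import os
import re

def latest_checkpoint_step(checkpoint_dir):
    """Return the largest numeric step subdirectory under checkpoint_dir,
    or 0 if none exist. Hypothetical helper; directory naming is assumed."""
    steps = [int(name) for name in os.listdir(checkpoint_dir)
             if re.fullmatch(r"\d+", name)]
    return max(steps, default=0)
```

The sampler is then shifted by `latest_step * global_batch_size` records, so no Grain iterator state ever needs to be serialized into the checkpoint.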

reverse = _find_end_first_consecutive_true(arr[::-1])
return len(arr) - reverse if reverse > 0 else -1



🟡 The slicing prog[:-1:] is functional but non-idiomatic in Python. Using prog[:-1] is cleaner and more conventional.

Suggested change
true_locs = np.where(prog[:-1] == prog[1:])[0]


max_period = min(max_period, len(arr) // 3)

for period in range(min_period, max_period + 1):

🟢 Using a mask_value of -1 to pad rows before reshaping is a clever way to avoid periodic matches wrapping around rows. However, if -1 (or its uint32 representation 0xFFFFFFFF) is a valid token in the dataset, the ValueError will be triggered unnecessarily.

Consider using a value that is strictly outside the tokenizer's vocabulary range if available, or documenting that -1 is reserved.
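The padding trick this comment describes can be illustrated like so; `MASK` and `rows_for_period` are hypothetical names, not the PR's API:

```python
import numpy as np

MASK = -1  # assumed sentinel; must not be a valid token id

def rows_for_period(tokens, period):
    """Pad to a multiple of `period` with MASK, then reshape so each row
    holds one candidate repetition unit. Padding with a sentinel keeps a
    periodic match in the last row from wrapping into the next one."""
    tokens = np.asarray(tokens, dtype=np.int64)
    if MASK in tokens:  # the review's concern: a real token equal to MASK
        raise ValueError("mask value collides with a token id")
    pad = (-len(tokens)) % period
    padded = np.concatenate([tokens, np.full(pad, MASK, dtype=np.int64)])
    return padded.reshape(-1, period)

print(rows_for_period([5, 6, 5, 6, 5], 2))
# [[ 5  6]
#  [ 5  6]
#  [ 5 -1]]
```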

total_instances: ``index.total_instances`` from the OLMo index.
seed: Base seed for the shuffle.
shard_index: Zero-based index of this data-loading host. Typically
``jax.process_index()``.

🟡 For very large datasets (e.g., the 724M instance mix mentioned), allocating the full permutation in host memory (~5.8 GB) can be a significant spike, especially if many hosts are doing it simultaneously at an epoch boundary.

While acceptable for the current scope, consider implementing a lazy or on-disk permutation scheme if the dataset size grows further or if host memory becomes a constraint.
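One low-memory alternative hinted at here is a seeded index bijection computed per lookup instead of a materialized permutation array. A sketch (weaker shuffle quality than a true random permutation, but O(1) memory):

```python
import math

def lazy_perm(i, n, seed):
    """Memory-free bijection on range(n): an affine map i -> (a*i + b) % n
    with gcd(a, n) == 1. Illustrative only, not the PR's scheme."""
    a = 2 * seed + 1              # start from an odd multiplier
    while math.gcd(a, n) != 1:    # ensure the map is a bijection
        a += 2
    b = seed % n
    return (a * i + b) % n

n = 10
perm = [lazy_perm(i, n, seed=7) for i in range(n)]
assert sorted(perm) == list(range(n))  # every index appears exactly once
```

A production version would want a stronger mixing function (e.g. a Feistel network over the index space), but the memory profile is the point: nothing scales with the 724M-instance mix.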

@RissyRan (Collaborator) left a comment

Thanks for the change!

Have you had a chance to chat with @aireenmei about this yet? I'm wondering if we could design these features to directly leverage the existing grain_data_processing.py. For example, we could add things like n-gram filtering and pre-tokenized numpy mixes as features there to improve code reusability.

The main benefit I see is reducing maintenance overhead. We currently maintain both tfds and c4_tfds_mlperf, but the latter is rarely used and has some maintenance issues. Since Grain will be heavily used moving forward, it makes sense to build on top of it directly. Let me know your thoughts—happy to discuss!


```bash
python tools/data_generation/download_olmo_data_to_gcs.py \
--mix-file /path/to/OLMo-mix-0925-official.txt \
```

Will this path be on GCS, or is that not necessary?
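On the download side, the HTTP-Range resume the description mentions boils down to requesting from the current file size onward. A minimal sketch of the idea (the real script's flags and GCS upload step are not shown):

```python
import os
import urllib.error
import urllib.request

def download_with_resume(url, dest, chunk_size=1 << 20):
    """Resume a partial download via an HTTP Range request. A sketch of
    the technique behind download_olmo_data_to_gcs.py, not its code."""
    have = os.path.getsize(dest) if os.path.exists(dest) else 0
    req = urllib.request.Request(url, headers={"Range": f"bytes={have}-"})
    try:
        with urllib.request.urlopen(req) as resp, open(dest, "ab") as out:
            if resp.status == 200 and have:
                out.truncate(0)  # server ignored Range: restart from scratch
            while chunk := resp.read(chunk_size):
                out.write(chunk)
    except urllib.error.HTTPError as err:
        if err.code != 416:  # 416 Range Not Satisfiable: already complete
            raise
```

A 206 Partial Content response means the server honored the Range header and only the missing suffix is transferred.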

"""

olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.")
olmo_path_remap_from: PathStr = Field(
The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment here explaining why we need this remapping?
